{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "OsFaZscKSPvo"
      },
      "outputs": [{"output_type": "stream", "name": "stdout", "text": ["\n"]}],
      "source": [
        "# @title ###### Licensed to the Apache Software Foundation (ASF), Version 2.0 (the \"License\")\n",
        "\n",
        "# Licensed to the Apache Software Foundation (ASF) under one\n",
        "# or more contributor license agreements. See the NOTICE file\n",
        "# distributed with this work for additional information\n",
        "# regarding copyright ownership. The ASF licenses this file\n",
        "# to you under the Apache License, Version 2.0 (the\n",
        "# \"License\"); you may not use this file except in compliance\n",
        "# with the License. You may obtain a copy of the License at\n",
        "#\n",
        "#   http://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing,\n",
        "# software distributed under the License is distributed on an\n",
        "# \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n",
        "# KIND, either express or implied. See the License for the\n",
        "# specific language governing permissions and limitations\n",
        "# under the License"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "NrHRIznKp3nS"
      },
      "source": [
        "# Run ML inference by using vLLM on GPUs\n",
        "\n",
        "<table align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/apache/beam/blob/master/examples/notebooks/beam-ml/run_inference_vllm.ipynb\"><img src=\"https://raw.githubusercontent.com/google/or-tools/main/tools/colab_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/apache/beam/blob/master/examples/notebooks/beam-ml/run_inference_vllm.ipynb\"><img src=\"https://raw.githubusercontent.com/google/or-tools/main/tools/github_32px.png\" />View source on GitHub</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "H0ZFs9rDvtJm"
      },
      "source": [
        "[vLLM](https://github.com/vllm-project/vllm) is a fast and user-friendly library for LLM inference and serving. vLLM optimizes LLM inference with mechanisms like PagedAttention for memory management and continuous batching for increasing throughput. For popular models, vLLM has been shown to increase throughput by a multiple of 2 to 4. With Apache Beam, you can serve models with vLLM and scale that serving with just a few lines of code.\n",
        "\n",
        "This notebook demonstrates how to run machine learning inference by using vLLM and GPUs in three ways:\n",
        "\n",
        "* locally without Apache Beam\n",
        "* locally with the Apache Beam local runner\n",
        "* remotely with the Dataflow runner\n",
        "\n",
        "It also shows how to swap in a different model without modifying your pipeline structure by changing the configuration."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6x41tnbTvQM1"
      },
      "source": [
        "## Requirements\n",
        "\n",
        "This notebook assumes that a GPU is enabled in Colab. If this setting isn't enabled, the locally executed sections of this notebook might not work. To enable a GPU, in the Colab menu, click **Runtime** > **Change runtime type**. For **Hardware accelerator**, choose a GPU accelerator. If you can't access a GPU in Colab, you can run the Dataflow section of this notebook.\n",
        "\n",
        "To run the Dataflow section, you need access to the following resources:\n",
        "\n",
        "- a computer with Docker installed\n",
        "- a [Google Cloud](https://cloud.google.com/) account"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8PSjyDIavRcn"
      },
      "source": [
        "## Install dependencies\n",
        "\n",
        "Before creating your pipeline, download and install the dependencies required to develop with Apache Beam and vLLM. vLLM is supported in Apache Beam versions 2.60.0 and later."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "irCKNe42p22r"
      },
      "outputs": [{"output_type": "stream", "name": "stdout", "text": ["\n"]}],
      "source": [
        "!pip install \"openai>=1.52.2\"\n",
        "!pip install \"vllm>=0.6.3\"\n",
        "!pip install \"triton>=3.1.0\"\n",
        "!pip install \"apache-beam[gcp]==2.61.0\"\n",
        "!pip install nest_asyncio # only needed in Colab\n",
        "!pip check"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Colab only: allow nested asyncio\n",
        "\n",
        "The vLLM model handler logic below uses asyncio to feed records to vLLM. This only works if the code isn't already running inside an asyncio event loop. That is usually the case, but Colab itself runs in an event loop. To work around this, use `nest_asyncio`, which allows nested event loops. Don't include this step outside of Colab."
      ],
      "metadata": {
        "id": "3xz8zuA7vcS3"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# This step should not be necessary outside of Colab.\n",
        "import nest_asyncio\n",
        "nest_asyncio.apply()\n"
      ],
      "metadata": {
        "id": "sUqjOzw3wpI3"
      },
      "execution_count": null,
      "outputs": [{"output_type": "stream", "name": "stdout", "text": ["\n"]}]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sUqjOzw3wpI4"
      },
      "source": [
        "## Run locally without Apache Beam\n",
        "\n",
        "In this section, you run a vLLM server without using Apache Beam. Use the `facebook/opt-125m` model. This model is small enough to fit in Colab memory and doesn't require any extra authentication.\n",
        "\n",
        "First, start the vLLM server. This step might take a minute or two, because the model needs to download before vLLM starts running inference."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "GbJGzINNt5sG"
      },
      "outputs": [{"output_type": "stream", "name": "stdout", "text": ["\n"]}],
      "source": [
        "! python -m vllm.entrypoints.openai.api_server --model facebook/opt-125m"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "n35LXTS3uzIC"
      },
      "source": [
        "Next, while the vLLM server is running, open a separate terminal to communicate with the vLLM serving process. To open a terminal in Colab, in the sidebar, click **Terminal**. In the terminal, run the following commands.\n",
        "\n",
        "```\n",
        "pip install openai\n",
        "python\n",
        "\n",
        "from openai import OpenAI\n",
        "\n",
        "# Modify OpenAI's API key and API base to use vLLM's API server.\n",
        "openai_api_key = \"EMPTY\"\n",
        "openai_api_base = \"http://localhost:8000/v1\"\n",
        "client = OpenAI(\n",
        "    api_key=openai_api_key,\n",
        "    base_url=openai_api_base,\n",
        ")\n",
        "completion = client.completions.create(model=\"facebook/opt-125m\",\n",
        "                                      prompt=\"San Francisco is a\")\n",
        "print(\"Completion result:\", completion)\n",
        "```\n",
        "\n",
        "This code runs against the server running in the cell. You can experiment with different prompts."
      ]
    },
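    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The OpenAI client sends ordinary HTTP requests, so you can also build the same request by hand. The following is a minimal sketch that uses only the Python standard library. The `build_completion_request` helper and the `max_tokens` value are illustrative assumptions, and the URL assumes the OpenAI-compatible `/v1/completions` route on the server started in the previous cell.\n",
        "\n",
        "```python\n",
        "import json\n",
        "import urllib.request\n",
        "\n",
        "def build_completion_request(prompt, model='facebook/opt-125m',\n",
        "                             base_url='http://localhost:8000/v1'):\n",
        "  # Hypothetical helper: builds, but does not send, a completions request.\n",
        "  payload = {'model': model, 'prompt': prompt, 'max_tokens': 16}\n",
        "  return urllib.request.Request(\n",
        "      base_url + '/completions',\n",
        "      data=json.dumps(payload).encode('utf-8'),\n",
        "      headers={'Content-Type': 'application/json'},\n",
        "      method='POST')\n",
        "\n",
        "request = build_completion_request('San Francisco is a')\n",
        "# To send the request while the server is running:\n",
        "# with urllib.request.urlopen(request) as response:\n",
        "#   print(json.load(response))\n",
        "```\n",
        "\n",
        "The request is only constructed here. Sending it requires the vLLM server to be running."
      ]
    },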
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Hbxi83BfwbBa"
      },
      "source": [
        "## Run locally with Apache Beam\n",
        "\n",
        "In this section, you set up an Apache Beam pipeline to run a job with an embedded vLLM instance.\n",
        "\n",
        "First, define the `VLLMCompletionsModelHandler` object. This configuration object gives Apache Beam the information that it needs to create a dedicated vLLM process in the middle of the pipeline. Apache Beam then provides examples to the pipeline. No additional code is needed."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "sUqjOzw3wpI4"
      },
      "outputs": [{"output_type": "stream", "name": "stdout", "text": ["\n"]}],
      "source": [
        "from apache_beam.ml.inference.base import RunInference\n",
        "from apache_beam.ml.inference.vllm_inference import VLLMCompletionsModelHandler\n",
        "from apache_beam.ml.inference.base import PredictionResult\n",
        "import apache_beam as beam\n",
        "\n",
        "model_handler = VLLMCompletionsModelHandler('facebook/opt-125m')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "N06lXRKRxCz5"
      },
      "source": [
        "Next, define examples to run inference against, and define a helper function to print out the inference results."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3a1PznmtxNR_"
      },
      "outputs": [{"output_type": "stream", "name": "stdout", "text": ["\n"]}],
      "source": [
        "class FormatOutput(beam.DoFn):\n",
        "  def process(self, element, *args, **kwargs):\n",
        "    yield \"Input: {input}, Output: {output}\".format(input=element.example, output=element.inference)\n",
        "\n",
        "prompts = [\n",
        "    \"Hello, my name is\",\n",
        "    \"The president of the United States is\",\n",
        "    \"The capital of France is\",\n",
        "    \"The future of AI is\",\n",
        "    \"Emperor penguins are\",\n",
        "]"
      ]
    },
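    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "`RunInference` emits `PredictionResult` objects, and `FormatOutput` reads their `example` and `inference` fields. The following sketch shows the formatting step in isolation. It uses a hypothetical `FakeResult` stand-in and a placeholder output string instead of a real `PredictionResult`, so it runs without a model.\n",
        "\n",
        "```python\n",
        "from collections import namedtuple\n",
        "\n",
        "# Hypothetical stand-in that mimics the two PredictionResult fields used here.\n",
        "FakeResult = namedtuple('FakeResult', ['example', 'inference'])\n",
        "\n",
        "def format_result(element):\n",
        "  # Same format string as FormatOutput.process.\n",
        "  return 'Input: {input}, Output: {output}'.format(\n",
        "      input=element.example, output=element.inference)\n",
        "\n",
        "line = format_result(FakeResult('The capital of France is', '<model output>'))\n",
        "print(line)\n",
        "```"
      ]
    },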
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Njl0QfrLxQ0m"
      },
      "source": [
        "Finally, run the pipeline.\n",
        "\n",
        "This step might take a minute or two, because the model needs to download before Apache Beam can start running inference."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9yXbzV0ZmZcJ"
      },
      "outputs": [{"output_type": "stream", "name": "stdout", "text": ["\n"]}],
      "source": [
        "with beam.Pipeline() as p:\n",
        "  _ = (p | beam.Create(prompts) # Create a PCollection of the prompts.\n",
        "         | RunInference(model_handler) # Send the prompts to the model and get responses.\n",
        "         | beam.ParDo(FormatOutput()) # Format the output.\n",
        "         | beam.Map(print) # Print the formatted output.\n",
        "  )"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Jv7be6Pk9Hlx"
      },
      "source": [
        "## Run remotely on Dataflow\n",
        "\n",
        "After you validate that the pipeline can run against a local vLLM instance, you can productionize the workflow on a remote runner. This notebook runs the pipeline on the Dataflow runner."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "J1LMrl1Yy6QB"
      },
      "source": [
        "### Build a Docker image\n",
        "\n",
        "To run a pipeline with vLLM on Dataflow, you must create a Docker image that contains your dependencies and is compatible with a GPU runtime. For more information about building GPU-compatible Dataflow containers, see [Build a custom container image](https://cloud.google.com/dataflow/docs/gpu/use-gpus#custom-container) in the Dataflow documentation.\n",
        "\n",
        "First, define and save your Dockerfile. This file uses an Nvidia GPU-compatible base image. In the Dockerfile, install the Python dependencies needed to run the job.\n",
        "\n",
        "Before proceeding, make sure that your configuration meets the following requirements:\n",
        "\n",
        "- The Python version in the following cell matches the Python version defined in the Dockerfile.\n",
        "- The Apache Beam version defined in your dependencies matches the Apache Beam version defined in the Dockerfile."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jCQ6-D55gqfl"
      },
      "outputs": [{"output_type": "stream", "name": "stdout", "text": ["\n"]}],
      "source": [
        "!python --version"
      ]
    },
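    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "One way to check the first requirement programmatically is to compare the running interpreter against the version that the Dockerfile installs. This is a minimal sketch. `DOCKERFILE_PYTHON` and `matches_dockerfile` are illustrative names, and the `(3, 10)` value assumes the `python3.10` packages used in the Dockerfile in the next cell.\n",
        "\n",
        "```python\n",
        "import sys\n",
        "\n",
        "# Assumption: the Dockerfile installs Python 3.10.\n",
        "DOCKERFILE_PYTHON = (3, 10)\n",
        "\n",
        "def matches_dockerfile(version_info=sys.version_info, expected=DOCKERFILE_PYTHON):\n",
        "  # Compare only the major and minor version numbers.\n",
        "  return tuple(version_info[:2]) == tuple(expected)\n",
        "\n",
        "if not matches_dockerfile():\n",
        "  print('Local Python is %d.%d; update the Dockerfile to match.' % sys.version_info[:2])\n",
        "```"
      ]
    },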
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7QyNq_gygHLO"
      },
      "outputs": [{"output_type": "stream", "name": "stdout", "text": ["\n"]}],
      "source": [
        "cell_str='''\n",
        "FROM nvidia/cuda:12.4.1-devel-ubuntu22.04\n",
        "\n",
        "RUN apt update\n",
        "RUN apt install software-properties-common -y\n",
        "RUN add-apt-repository ppa:deadsnakes/ppa\n",
        "RUN apt update\n",
        "RUN apt-get update\n",
        "\n",
        "ARG DEBIAN_FRONTEND=noninteractive\n",
        "\n",
        "RUN apt install python3.10-full -y\n",
        "# RUN apt install python3.10-venv -y\n",
        "# RUN apt install python3.10-dev -y\n",
        "RUN rm /usr/bin/python3\n",
        "RUN ln -s python3.10 /usr/bin/python3\n",
        "RUN python3 --version\n",
        "RUN apt-get install -y curl\n",
        "RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python3.10 && pip install --upgrade pip\n",
        "\n",
        "# Copy the Apache Beam worker dependencies from the Beam Python 3.10 SDK image.\n",
        "COPY --from=apache/beam_python3.10_sdk:2.61.0 /opt/apache/beam /opt/apache/beam\n",
        "\n",
        "RUN pip install --no-cache-dir -vvv apache-beam[gcp]==2.61.0\n",
        "# Quote the version specifiers so that the shell does not treat > as a redirect.\n",
        "RUN pip install \"openai>=1.52.2\" \"vllm>=0.6.3\" \"triton>=3.1.0\"\n",
        "\n",
        "RUN apt install libcairo2-dev pkg-config python3-dev -y\n",
        "RUN pip install pycairo\n",
        "\n",
        "# Set the entrypoint to Apache Beam SDK worker launcher.\n",
        "ENTRYPOINT [ \"/opt/apache/beam/boot\" ]\n",
        "'''\n",
        "\n",
        "with open('VllmDockerfile', 'w') as f:\n",
        "  f.write(cell_str)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zWma0YetiEn5"
      },
      "source": [
        "After you save the Dockerfile, build and push your Docker image. Because Docker is not accessible from Colab, you need to complete this step in a separate environment.\n",
        "\n",
        "1.   In the sidebar, click **Files** to open the **Files** pane.\n",
        "2.   In an environment with Docker installed, download the **VllmDockerfile** file to an empty folder.\n",
        "3.   Run the following commands. Replace `<REPOSITORY_NAME>:<TAG>` with a valid [Artifact Registry](https://cloud.google.com/artifact-registry/docs/overview) repository and tag.\n",
        "\n",
        "  ```\n",
        "  docker build -t \"<REPOSITORY_NAME>:<TAG>\" -f VllmDockerfile ./\n",
        "  docker image push \"<REPOSITORY_NAME>:<TAG>\"\n",
        "  ```"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "NjZyRjte0g0Q"
      },
      "source": [
        "### Define and run the pipeline\n",
        "\n",
        "When you have a working Docker image, define and run your pipeline.\n",
        "\n",
        "First, define the pipeline options that you want to use to launch the Dataflow job. Before running the next cell, replace the following variables:\n",
        "\n",
        "- `<BUCKET_NAME>`: the name of a valid [Google Cloud Storage](https://cloud.google.com/storage?e=48754805&hl=en) bucket. Don't include a `gs://` prefix or trailing slashes.\n",
        "- `<REPOSITORY_NAME>`: the name of the Google Artifact Registry repository that you used in the previous step.\n",
        "- `<TAG>`: the image tag that you used in the previous step. Prefer a versioned tag or SHA over `:latest` or other mutable tags.\n",
        "- `<PROJECT_ID>`: the ID of the Google Cloud project that you created your bucket and Artifact Registry repository in.\n",
        "\n",
        "This workflow uses the following Dataflow service option: `worker_accelerator=type:nvidia-tesla-t4;count:1;install-nvidia-driver:5xx`. When you use this service option, Dataflow attaches one T4 GPU to each worker machine and installs a `5xx` series Nvidia driver. The `5xx` driver is required to run vLLM jobs."
      ]
    },
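    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The service option described above is a single string of `key:value` pairs separated by semicolons. The following sketch assembles it from its parts so that the GPU type or count can be changed in one place. The `accelerator_option` helper is illustrative and isn't part of Apache Beam.\n",
        "\n",
        "```python\n",
        "def accelerator_option(gpu_type='nvidia-tesla-t4', count=1, driver='5xx'):\n",
        "  # Hypothetical helper: joins the parts in the format used by this notebook.\n",
        "  parts = [\n",
        "      'type:%s' % gpu_type,\n",
        "      'count:%d' % count,\n",
        "      'install-nvidia-driver:%s' % driver,\n",
        "  ]\n",
        "  return 'worker_accelerator=' + ';'.join(parts)\n",
        "\n",
        "print(accelerator_option())\n",
        "```"
      ]
    },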
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kXy9FRYVCSjq"
      },
      "outputs": [{"output_type": "stream", "name": "stdout", "text": ["\n"]}],
      "source": [
        "\n",
        "from apache_beam.options.pipeline_options import GoogleCloudOptions\n",
        "from apache_beam.options.pipeline_options import PipelineOptions\n",
        "from apache_beam.options.pipeline_options import SetupOptions\n",
        "from apache_beam.options.pipeline_options import StandardOptions\n",
        "from apache_beam.options.pipeline_options import WorkerOptions\n",
        "\n",
        "\n",
        "options = PipelineOptions()\n",
        "\n",
        "# Replace with your bucket name.\n",
        "BUCKET_NAME = '<BUCKET_NAME>' # @param {type:'string'}\n",
        "# Replace with the image repository and tag from the previous step.\n",
        "CONTAINER_IMAGE = '<REPOSITORY_NAME>:<TAG>'  # @param {type:'string'}\n",
        "# Replace with your GCP project\n",
        "PROJECT_NAME = '<PROJECT_ID>' # @param {type:'string'}\n",
        "\n",
        "options.view_as(GoogleCloudOptions).project = PROJECT_NAME\n",
        "\n",
        "# Provide required pipeline options for the Dataflow Runner.\n",
        "options.view_as(StandardOptions).runner = \"DataflowRunner\"\n",
        "\n",
        "# Set the Google Cloud region that you want to run Dataflow in.\n",
        "options.view_as(GoogleCloudOptions).region = 'us-central1'\n",
        "\n",
        "# IMPORTANT: Replace BUCKET_NAME with the name of your Cloud Storage bucket.\n",
        "dataflow_gcs_location = \"gs://%s/dataflow\" % BUCKET_NAME\n",
        "\n",
        "# The Dataflow staging location. This location is used to stage the Dataflow pipeline and the SDK binary.\n",
        "options.view_as(GoogleCloudOptions).staging_location = '%s/staging' % dataflow_gcs_location\n",
        "\n",
        "# The Dataflow temp location. This location is used to store temporary files or intermediate results before outputting to the sink.\n",
        "options.view_as(GoogleCloudOptions).temp_location = '%s/temp' % dataflow_gcs_location\n",
        "\n",
        "# Enable the GPU runtime. Make sure to request the 5xx driver, because vLLM works only with 5xx drivers, not 4xx drivers.\n",
        "options.view_as(GoogleCloudOptions).dataflow_service_options = [\"worker_accelerator=type:nvidia-tesla-t4;count:1;install-nvidia-driver:5xx\"]\n",
        "\n",
        "options.view_as(SetupOptions).save_main_session = True\n",
        "\n",
        "# Choose a machine type compatible with GPU type\n",
        "options.view_as(WorkerOptions).machine_type = \"n1-standard-4\"\n",
        "\n",
        "options.view_as(WorkerOptions).sdk_container_image = CONTAINER_IMAGE"
      ]
    },
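    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Because the bucket name is interpolated into `gs://` paths, a stray prefix or trailing slash produces a malformed location. The following sketch shows one way to catch that before launching the job. `dataflow_locations` is a hypothetical helper, not a Beam API, and `my-bucket` is a placeholder name.\n",
        "\n",
        "```python\n",
        "def dataflow_locations(bucket_name):\n",
        "  # Hypothetical helper: reject names that would produce malformed gs:// paths.\n",
        "  if bucket_name.startswith('gs://') or bucket_name.endswith('/'):\n",
        "    raise ValueError('Use the bare bucket name, without gs:// or a trailing slash.')\n",
        "  root = 'gs://%s/dataflow' % bucket_name\n",
        "  return {'staging_location': root + '/staging', 'temp_location': root + '/temp'}\n",
        "\n",
        "print(dataflow_locations('my-bucket'))\n",
        "```"
      ]
    },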
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xPhe597P1-QJ"
      },
      "source": [
        "Next, authenticate Colab so that it can submit a job on your behalf."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Xkf6yIVlFB8-"
      },
      "outputs": [{"output_type": "stream", "name": "stdout", "text": ["\n"]}],
      "source": [
        "def auth_to_colab():\n",
        "  from google.colab import auth\n",
        "  auth.authenticate_user()\n",
        "\n",
        "auth_to_colab()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MJtEI6Ux2eza"
      },
      "source": [
        "Finally, run the pipeline on Dataflow. The pipeline definition is almost exactly the same as the definition used for local execution. The pipeline options are the only change to the pipeline.\n",
        "\n",
        "The following code creates a Dataflow job in your project. You can view the results in Colab or in the Google Cloud console. Creating a Dataflow job and downloading the model might take a few minutes. After the job starts performing inference, it quickly runs through the inputs."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8gjDdru_9Dii"
      },
      "outputs": [{"output_type": "stream", "name": "stdout", "text": ["\n"]}],
      "source": [
        "import logging\n",
        "from apache_beam.ml.inference.base import RunInference\n",
        "from apache_beam.ml.inference.vllm_inference import VLLMCompletionsModelHandler\n",
        "from apache_beam.ml.inference.base import PredictionResult\n",
        "import apache_beam as beam\n",
        "\n",
        "class FormatOutput(beam.DoFn):\n",
        "  def process(self, element, *args, **kwargs):\n",
        "    yield \"Input: {input}, Output: {output}\".format(input=element.example, output=element.inference)\n",
        "\n",
        "logging.getLogger().setLevel(logging.INFO)  # Output additional Dataflow Job metadata and launch logs.\n",
        "prompts = [\n",
        "    \"Hello, my name is\",\n",
        "    \"The president of the United States is\",\n",
        "    \"The capital of France is\",\n",
        "    \"The future of AI is\",\n",
        "    \"John Cena is\",\n",
        "]\n",
        "\n",
        "# Specify the model handler, providing the name of the model to serve.\n",
        "model_handler = VLLMCompletionsModelHandler('facebook/opt-125m')\n",
        "\n",
        "with beam.Pipeline(options=options) as p:\n",
        "  _ = (p | beam.Create(prompts) # Create a PCollection of the prompts.\n",
        "         | RunInference(model_handler) # Send the prompts to the model and get responses.\n",
        "         | beam.ParDo(FormatOutput()) # Format the output.\n",
        "         | beam.Map(logging.info) # Log the formatted output.\n",
        "  )"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "22cEHPCc28fH"
      },
      "source": [
        "## Run vLLM with a Gemma model\n",
        "\n",
        "After you configure your pipeline, switching the model used by the pipeline is relatively straightforward. You can run the same pipeline, but switch the model name defined in the model handler. This example runs the pipeline created previously but uses a [Gemma](https://ai.google.dev/gemma) model.\n",
        "\n",
        "Before you start, sign in to Hugging Face, and make sure that you can access the Gemma models. To access Gemma models, you must accept the terms and conditions.\n",
        "\n",
        "1.   Navigate to the [Gemma Model Card](https://huggingface.co/google/gemma-2b).\n",
        "2.   Sign in, or sign up for a free Hugging Face account.\n",
        "3.   Follow the prompts to agree to the conditions.\n",
        "\n",
        "When you complete these steps, the following message appears on the model card page: `You have been granted access to this model`.\n",
        "\n",
        "Next, sign in to your account from this notebook by running the following code and then following the prompts."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "JHwIsFI9kd9j"
      },
      "outputs": [{"output_type": "stream", "name": "stdout", "text": ["\n"]}],
      "source": [
        "! huggingface-cli login"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "IjX2If8rnCol"
      },
      "source": [
        "Verify that the notebook can now access the Gemma model. Run the following code, which starts a vLLM server to serve the Gemma 2b model. Because the default T4 Colab runtime doesn't support the full data type precision needed to run Gemma models, the `--dtype=half` parameter is required.\n",
        "\n",
        "When successful, the following cell runs indefinitely. The server process starts only after the Gemma 2b model downloads successfully, so when the server is ready to serve traffic, access to the model is confirmed and you can stop the cell."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "LH_oCFWMiwFs"
      },
      "outputs": [{"output_type": "stream", "name": "stdout", "text": ["\n"]}],
      "source": [
        "! python -m vllm.entrypoints.openai.api_server --model google/gemma-2b --dtype=half"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "31BmdDUAn-SW"
      },
      "source": [
        "To run the pipeline in Apache Beam, run the following code. Update the `VLLMCompletionsModelHandler` object with the new parameters, which match the command from the previous cell. Reuse all of the pipeline logic from the previous pipelines."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "DyC2ikXg237p"
      },
      "outputs": [{"output_type": "stream", "name": "stdout", "text": ["\n"]}],
      "source": [
        "model_handler = VLLMCompletionsModelHandler('google/gemma-2b', vllm_server_kwargs={'dtype': 'half'})\n",
        "\n",
        "with beam.Pipeline() as p:\n",
        "  _ = (p | beam.Create(prompts) # Create a PCollection of the prompts.\n",
        "         | RunInference(model_handler) # Send the prompts to the model and get responses.\n",
        "         | beam.ParDo(FormatOutput()) # Format the output.\n",
        "         | beam.Map(print) # Print the formatted output.\n",
        "  )"
      ]
    },
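    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The `vllm_server_kwargs` dictionary mirrors the flags passed on the command line in the earlier cell: the `dtype` key corresponds to the `--dtype` flag. The following sketch illustrates that correspondence only. `kwargs_to_flags` is a hypothetical helper and is not a copy of how Apache Beam builds the server command.\n",
        "\n",
        "```python\n",
        "def kwargs_to_flags(server_kwargs):\n",
        "  # Illustrative only: render each key/value pair as a '--key value' flag pair.\n",
        "  flags = []\n",
        "  for key, value in server_kwargs.items():\n",
        "    flags.extend(['--' + key, str(value)])\n",
        "  return flags\n",
        "\n",
        "print(kwargs_to_flags({'dtype': 'half'}))\n",
        "```"
      ]
    },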
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "C6OYfub6ovFK"
      },
      "source": [
        "### Run Gemma on Dataflow\n",
        "\n",
        "As a next step, run this pipeline on Dataflow. Follow the same steps described in the \"Run remotely on Dataflow\" section of this notebook:\n",
        "\n",
        "1.   Construct a Dockerfile and push a new Docker image. You can use the same Dockerfile that you created previously, but you need to add a step to set your Hugging Face authentication token. In your Dockerfile, add the following line before the entrypoint:\n",
        "\n",
        "  ```\n",
        "  RUN python3 -c 'from huggingface_hub import HfFolder; HfFolder.save_token(\"<TOKEN>\")'\n",
        "  ```\n",
        "\n",
        "2.   Set pipeline options. You can reuse the options defined in this notebook. Replace the Docker image location with your new Docker image.\n",
        "3.   Run the pipeline. Copy the pipeline that you ran on Dataflow, and replace the pipeline options with the pipeline options that you just defined.\n",
        "\n"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "gpuType": "T4",
      "provenance": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
